{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-nnall6G_S_H"
      },
      "source": [
        "Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BC8NK7QSt7DJ"
      },
      "source": [
        "# Install packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3x_emgqK0d7P"
      },
      "outputs": [],
      "source": [
        "!pip install openai==0.28\n",
        "!pip install tiktoken\n",
        "!pip install tqdm\n",
        "!pip install matplotlib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a1eJJ9HU5rSV"
      },
      "source": [
        "# Import packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GIRBhmE55lDN"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import openai\n",
        "import tiktoken\n",
        "from tqdm.auto import trange, tqdm\n",
        "import time\n",
        "import os\n",
        "import json\n",
        "from tqdm import tqdm\n",
        "import re\n",
        "from types import NoneType\n",
        "import multiprocessing.dummy\n",
        "from io import StringIO\n",
        "from contextlib import redirect_stdout\n",
        "import signal\n",
        "from contextlib import contextmanager\n",
        "import matplotlib.pyplot as plt\n",
        "import sys\n",
        "import ast\n",
        "import copy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amX1hDXoEfxI"
      },
      "source": [
        "# Set up API key"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jWwegfs4Emb6"
      },
      "outputs": [],
      "source": [
        "# Read the key from the OPENAI_API_KEY environment variable when it is set, so the key\n",
        "# does not have to be saved in the notebook; otherwise replace the placeholder below.\n",
        "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"\u003cadd your API key here\u003e\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OY73QDgTUKiN"
      },
      "source": [
        "# Global variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3F_Kl5jyUItW"
      },
      "outputs": [],
      "source": [
        "# Model name passed to the OpenAI completion API by query_llm.\n",
        "ENGINE = 'gpt-3.5-turbo-instruct'\n",
        "# Reference answer that print_result prints next to every model answer.\n",
        "CORRECT_ANSWER = '52'\n",
        "# Marker that precedes the final answer in the few-shot prompts; responses are split on it.\n",
        "ANSWER_TOKEN = 'Answer: '\n",
        "# Markers that delimit the generated program in a Chain of Code response.\n",
        "CODE_START_TOKEN = \"# CODE START\"\n",
        "CODE_END_TOKEN = \"# CODE END\"\n",
        "# NOTE(review): MAX_TOKENS does not appear to be referenced elsewhere in this notebook.\n",
        "MAX_TOKENS = 4096\n",
        "# tiktoken encoder that matches ENGINE.\n",
        "ENCODER = tiktoken.encoding_for_model(ENGINE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ERhPDLg_T1Ll"
      },
      "source": [
        "# Helper functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vreUkUlPoh3N"
      },
      "outputs": [],
      "source": [
        "def query_llm(prompt, max_tokens, stop=None, temperature=0):\n",
        "  \"\"\"Query the OpenAI completion API and return the generated text.\n",
        "\n",
        "  Args:\n",
        "    prompt: prompt string to complete.\n",
        "    max_tokens: maximum number of tokens to generate.\n",
        "    stop: optional stop sequence(s) forwarded to the API.\n",
        "    temperature: sampling temperature forwarded to the API.\n",
        "\n",
        "  Returns:\n",
        "    The text of the first returned choice, stripped of surrounding whitespace.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: if prompt is not a str.\n",
        "  \"\"\"\n",
        "  # Fail fast on a non-string prompt before spending an API call.\n",
        "  assert isinstance(prompt, str), f\"prompt must be a str, got {type(prompt).__name__}\"\n",
        "  response = openai.Completion.create(prompt=prompt, model=ENGINE, max_tokens=max_tokens, temperature=temperature, stop=stop)\n",
        "  return response.choices[0][\"text\"].strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cqgzkRu7YEi4"
      },
      "outputs": [],
      "source": [
        "def print_result(method, response, answer):\n",
        "  \"\"\"Print the method name, the raw response, the parsed answer and the reference answer.\"\"\"\n",
        "  sections = (\n",
        "      (\"#### Method ####\", method),\n",
        "      (\"#### Full Response ####\", response),\n",
        "      (\"#### Model Answer ####\", answer),\n",
        "  )\n",
        "  for header, body in sections:\n",
        "    print(header)\n",
        "    print(body)\n",
        "  # The reference answer comes from the notebook-level constant, not from the caller.\n",
        "  print(\"#### Correct Answer ####\")\n",
        "  print(CORRECT_ANSWER)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "siMv0ercUpbn"
      },
      "outputs": [],
      "source": [
        "# Module-level state shared between show_trace and evaluate_coc.\n",
        "# Maps a line index to the source line of a statement that raised; show_trace asks the LLM to simulate those lines.\n",
        "errors = {}\n",
        "# Line index of the first not-yet-handled exception seen by show_trace in the current run.\n",
        "error_lineno = None\n",
        "# Source lines of the wrapped program that is currently being executed.\n",
        "lines = None\n",
        "# Trace entries (state / delta state / line) collected by show_trace.\n",
        "trace_lines = []\n",
        "# Deep copy of the local state captured at the previously traced line.\n",
        "last_state = None\n",
        "\n",
        "def get_delta_state(state, last_state):\n",
        "  \"\"\"Return the entries of state that are missing from last_state or hold a different value.\"\"\"\n",
        "  return {\n",
        "      name: value\n",
        "      for name, value in state.items()\n",
        "      if name not in last_state or value != last_state[name]\n",
        "  }\n",
        "\n",
        "def get_state(frame):\n",
        "  \"\"\"Return the local variables of frame whose values are instances of plain built-in types.\"\"\"\n",
        "  traceable_types = (bool, str, int, float, tuple, list, set, dict, NoneType)\n",
        "  return {\n",
        "      name: value\n",
        "      for name, value in frame.f_locals.items()\n",
        "      if isinstance(value, traceable_types)\n",
        "  }\n",
        "\n",
        "def show_trace(frame, event, arg):\n",
        "  \"\"\"Trace hook that logs the execution of get_answer and lets the LLM simulate failing lines.\n",
        "\n",
        "  Meant to be installed with sys.settrace. Only frames whose code object is named\n",
        "  get_answer are traced. For every 'line' event the change in local state and the\n",
        "  source line are appended to trace_lines. If the line index is in errors, the last\n",
        "  three trace entries are also sent to the LLM and the delta state it returns is\n",
        "  written into the locals of the frame. The first 'exception' event on a line that is\n",
        "  not in errors stores that line index in error_lineno.\n",
        "\n",
        "  Args:\n",
        "    frame: frame that is being traced.\n",
        "    event: name of the trace event.\n",
        "    arg: event payload (unused).\n",
        "\n",
        "  Returns:\n",
        "    show_trace itself for get_answer frames, None for every other frame.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: if a traced line lacks the expected indentation, if 'return answer'\n",
        "      is not the second to last line, or if the LLM output is not a dict.\n",
        "    Exception: whatever ast.literal_eval raises when the LLM output cannot be parsed.\n",
        "  \"\"\"\n",
        "  # Declare the shared tracing state as global first\n",
        "  global errors\n",
        "  global error_lineno\n",
        "  global lines\n",
        "  global trace_lines\n",
        "  global last_state\n",
        "\n",
        "  # The LLM-generated code is wrapped in the get_answer function call.\n",
        "  # Without filtering on \"get_answer\" we would get a bunch of unrelated exceptions from Colab.\n",
        "  if frame.f_code.co_name != \"get_answer\":\n",
        "    return\n",
        "\n",
        "  # f_lineno is 1-based while lines is indexed from 0.\n",
        "  lineno = frame.f_lineno - 1\n",
        "  # Running a certain line\n",
        "  if event == \"line\":\n",
        "    current_line = lines[lineno]\n",
        "    stripped_line = current_line.strip()\n",
        "    if stripped_line in [\"try:\", \"except:\", \"pass\"]:\n",
        "      # try/except/pass scaffolding is not part of the trace.\n",
        "      pass\n",
        "    elif stripped_line == \"return answer\":\n",
        "      assert lineno == len(lines) - 2, \"return answer is at the wrong line\" # Second to last line\n",
        "      state = get_state(frame)\n",
        "      assert last_state is not None\n",
        "      delta_state = get_delta_state(state, last_state)\n",
        "      trace_lines.append(f\"delta state: {delta_state}\")\n",
        "      # Append the final state\n",
        "      trace_lines.append(f\"final state: {state}\")\n",
        "    else:\n",
        "      simulate_with_llm = lineno in errors\n",
        "      if simulate_with_llm:\n",
        "        # Lines in errors were indented by 4 spaces in total (function body + try block); undo that.\n",
        "        assert current_line[:4] == \"    \", f\"LLM: actual line to run doesn't have four leading spaces: {current_line} {lines}\"\n",
        "        current_line = current_line[4:]\n",
        "      else:\n",
        "        # Regular lines were indented by 2 spaces when wrapped in get_answer; undo that.\n",
        "        assert current_line[:2] == \"  \", f\"Python: actual line to run doesn't have two leading spaces: {current_line} {lines}\"\n",
        "        current_line = current_line[2:]\n",
        "\n",
        "      # Log what changed since the previously traced line, then the line that is about to run.\n",
        "      state = get_state(frame)\n",
        "      if last_state is None:\n",
        "        trace_lines.append(\"state: {}\")\n",
        "      else:\n",
        "        trace_lines.append(f\"delta state: {get_delta_state(state, last_state)}\")\n",
        "      # Deep copy so that later in-place mutations of lists/sets/dicts still show up as a delta.\n",
        "      last_state = copy.deepcopy(state)\n",
        "      trace_lines.append(f\"line: {current_line}\")\n",
        "\n",
        "      if simulate_with_llm:\n",
        "        # Due to the context length constraint, only feed in the last three lines of the trace.\n",
        "        prompt = coc_trace_prompt + \"\\n\" + \"\\n\".join(trace_lines[-3:]) + \"\\n\" + \"delta state:\"\n",
        "        llm_result = query_llm(prompt, max_tokens=32, stop=[\"\\nline:\"])\n",
        "        progress_bar.update()\n",
        "\n",
        "        # Parsing errors propagate to the caller unchanged.\n",
        "        new_program_state = ast.literal_eval(llm_result.strip())\n",
        "        assert isinstance(new_program_state, dict), \"new program state is not a valid dict\"\n",
        "        # Actually update the local variables with the new program state\n",
        "        frame.f_locals.update(new_program_state)\n",
        "\n",
        "  elif event == \"exception\":\n",
        "    # Only capture the lowest level exception AND only if this exception hasn't been \"fixed\" before,\n",
        "    # i.e. this line hasn't been sandwiched by try/except yet.\n",
        "    if error_lineno is None and lineno not in errors:\n",
        "      error_lineno = lineno\n",
        "\n",
        "  return show_trace\n",
        "\n",
        "sys.settrace(show_trace)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Wy6msNFcM4K"
      },
      "source": [
        "# Prompts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AVnghvmRUW8a"
      },
      "outputs": [],
      "source": [
        "direct_prompt = \"\"\"\n",
        "Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.\n",
        "Answer: 4\n",
        "\"\"\".strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rXNAEpuRVW03"
      },
      "outputs": [],
      "source": [
        "cot_prompt = \"\"\"\n",
        "Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.\n",
        "A:\n",
        "We'll group by countries and count:\n",
        "1. Spain: Bilbao\n",
        "2. USA: Death Valley, Honolulu\n",
        "3. France: Paris\n",
        "4. UK: Skye\n",
        "There are 4 countries in total. So the answer is 4.\n",
        "Answer: 4\n",
        "\"\"\".strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "INFt3mHRV84K"
      },
      "outputs": [],
      "source": [
        "coc_prompt = \"\"\"\n",
        "Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.\n",
        "A:\n",
        "# CODE START\n",
        "places = [\"Bilbao\", \"Death Valley\", \"Paris\", \"Honolulu\", \"Skye\"]\n",
        "countries = set()\n",
        "for place in places:\n",
        "  country = get_country(place)\n",
        "  countries.add(country)\n",
        "answer = len(countries)\n",
        "# CODE END\n",
        "Answer: 4\n",
        "\"\"\".strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U1ROiewDi2IS"
      },
      "outputs": [],
      "source": [
        "coc_trace_prompt = \"\"\"\n",
        "# TRACE START\n",
        "state: {}\n",
        "line: places = [\"Bilbao\", \"Death Valley\", \"Paris\", \"Honolulu\", \"Skye\"]\n",
        "delta state: {'places': ['Bilbao', 'Death Valley', 'Paris', 'Honolulu', 'Skye']}\n",
        "line: countries = set()\n",
        "delta state: {'countries': set()}\n",
        "line: for place in places:\n",
        "delta state: {'place': 'Bilbao'}\n",
        "line:   country = get_country(place)\n",
        "delta state: {'country': 'Spain'}\n",
        "line:   countries.add(country)\n",
        "delta state: {'countries': {'Spain'}}\n",
        "line: for place in places:\n",
        "delta state: {'place': 'Death Valley'}\n",
        "line:   country = get_country(place)\n",
        "delta state: {'country': 'USA'}\n",
        "line:   countries.add(country)\n",
        "delta state: {'countries': {'Spain', 'USA'}}\n",
        "line: for place in places:\n",
        "delta state: {'place': 'Paris'}\n",
        "line:   country = get_country(place)\n",
        "delta state: {'country': 'France'}\n",
        "line:   countries.add(country)\n",
        "delta state: {'countries': {'Spain', 'USA', 'France'}}\n",
        "line: for place in places:\n",
        "delta state: {'place': 'Honolulu'}\n",
        "line:   country = get_country(place)\n",
        "delta state: {'country': 'USA'}\n",
        "line:   countries.add(country)\n",
        "delta state: {'countries': {}}\n",
        "line: for place in places:\n",
        "delta state: {'place': 'Skye'}\n",
        "line:   country = get_country(place)\n",
        "delta state: {'country': 'UK'}\n",
        "line:   countries.add(country)\n",
        "delta state: {'countries': {'Spain', 'USA', 'France', 'UK'}}\n",
        "line: answer = len(countries)\n",
        "delta state: {'answer': 4}\n",
        "# TRACE END\n",
        "\n",
        "# TRACE START\n",
        "\"\"\".strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "042YPaL1WQW7"
      },
      "outputs": [],
      "source": [
        "query = \"\"\"\n",
        "Q: How many countries have I been to? I’ve been to Mumbai, London, Washington, Grand Canyon, Baltimore, Longsheng, Guilin, Beijing,\n",
        "Galapagos, Quito, Barcelona, Paris, Prague, Nice, Dehli, Agra, Rome, Florence, Amalfi, Athens, Míkonos, Málaga, Monaco, Berlin,\n",
        "Munich, Innsbruck, Bern, Milan, Lucerne, Gimmelwald (Schilthornbahn), St Moritz, St Petersburg, Helsinki, Amsterdam, Gdansk,\n",
        "Vancouver, Anchorage, Montreal, Belize, The Bahamas, Jamaica, Hawaii, Acadia National Park, Stockholm, Copenhagen, Dover, Lyon,\n",
        "Madrid, Toulouse, Santorini, Oslo, Kusadasi, Souda, Rhodes, Tallinn, Venice, Naples, Cape Town, Johannesburg, Addis Abeba,\n",
        "Nairobi, Seattle, San Francisco, Chicago, St Louis, Memphis, Chinle, Stanford, New York, Philadelphia, Boston, Miami,\n",
        "New Orleans, Walt Disney World Resort, Jacksonville, Las Vegas, Los Angeles, Portland, Salt Lake City, Tahoe City, Phoenix,\n",
        "Albuquerque, Cleveland, Charlottesville, Nags Head, Newfoundland and Labrador, Burlington, Wilmington, Myrtle Beach, St Lucia,\n",
        "Barbados, Banff, Haiti, Montego Bay, Sao Palo, Rio, Lima, Cusco, Cozumel, Amarillo, Yosemite National Park, Joshua Tree,\n",
        "Zion National Park, Bryce Canyon National Park, Grand Teton National Park, Yellowstone National Park, Glacier National Park, Mount Hood,\n",
        "Paso Robles, San Diego, Bend, North Cascades National Park, Olympic National Park Visitor Center, Jasper National Park,\n",
        "Sequoia National Park, Kings Canyon National Park, Shasta National Forest, Mount Saint Helens, Mount Rainier, Austin, Buenos Aires,\n",
        "El Calafate, El Chaltén, Fitz Roy, Torres del Paine National Park, Puerto Natales, Puerto Varas, Santiago, Marble Caves, Cerro Castillo,\n",
        "Coyhaique, Singapore, Casablanca, Marrakesh, Cairo, Jerusalem, Tokyo, Kyoto Prefecture, Taipei City, Taichung City, Krk,\n",
        "Naturpark Puez-Geisler, Ljubljana, Plitvice Lakes National Park, Fairbanks, Juneau, Dallas, Sydney, Cairns, Brisbane, Hook Island,\n",
        "Charleston, Panama City, Bangkok, Chiang Mai, Bengaluru, Denver, Indianapolis, Nashville, Blacksburg, Lisbon, Porto, Estes Park,\n",
        "Coeur d’Alene, Hood River, Denali, Sitka, Mexico City, Warsaw, Geneva, Auckland, Queenstown, Whitefish, Minneapolis, Sioux Falls,\n",
        "Bozeman, Missoula, Springfield, Skye, Edinburgh, Honolulu, Kauai, Haleakal¯a National Park, Wrangell-St. Elias National Park \u0026\n",
        "Preserve, Atlanta, Tirana, Corfu, Siena.\n",
        "\"\"\".strip()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tFCUPor6cPXE"
      },
      "source": [
        "# Demos\n",
        "\n",
        "Note: running these demos will cost around $0.2 for calling OpenAI APIs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G-_WuG9ncdzS"
      },
      "source": [
        "## Direct answer prompting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QPMsOLqSXM7D"
      },
      "outputs": [],
      "source": [
        "def evaluate_direct(prompt, query):\n",
        "  \"\"\"Ask the LLM for a direct answer to query and print the result.\"\"\"\n",
        "  full_prompt = \"\\n\\n\".join((prompt, query))\n",
        "  direct_response = query_llm(full_prompt, max_tokens=32)\n",
        "  # Text following the first answer marker, or the whole response when the marker is absent.\n",
        "  pieces = direct_response.split(ANSWER_TOKEN)\n",
        "  direct_answer = pieces[1].strip() if len(pieces) \u003e 1 else direct_response\n",
        "  print_result(\"Direct\", direct_response, direct_answer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ew6LOmS3ccmq"
      },
      "outputs": [],
      "source": [
        "evaluate_direct(direct_prompt, query)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_q4ccJIjc3h1"
      },
      "source": [
        "## CoT prompting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3EijvUBhYqQN"
      },
      "outputs": [],
      "source": [
        "def evaluate_cot(prompt, query):\n",
        "  \"\"\"Ask the LLM for a chain-of-thought answer to query and print the result.\"\"\"\n",
        "  full_prompt = \"\\n\\n\".join((prompt, query))\n",
        "  cot_response = query_llm(full_prompt, max_tokens=3072)\n",
        "  # Text following the first answer marker, or the whole response when the marker is absent.\n",
        "  pieces = cot_response.split(ANSWER_TOKEN)\n",
        "  cot_answer = pieces[1].strip() if len(pieces) \u003e 1 else cot_response\n",
        "  print_result(\"CoT\", cot_response, cot_answer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kSm-Nn_2c5zq"
      },
      "outputs": [],
      "source": [
        "evaluate_cot(cot_prompt, query)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KBOPxUtvc853"
      },
      "source": [
        "## CoC prompting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jM_nYY2ubXCX"
      },
      "outputs": [],
      "source": [
        "def evaluate_coc(prompt, query):\n",
        "  \"\"\"Run Chain of Code on query and print the result.\n",
        "\n",
        "  The LLM first generates a program between the code start/end markers. The program is\n",
        "  wrapped in a get_answer function and executed under the show_trace hook. Whenever a\n",
        "  line raises, it is wrapped in try/except and registered in errors so that show_trace\n",
        "  lets the LLM simulate it on the next trial. At most 20 trials are made.\n",
        "\n",
        "  Args:\n",
        "    prompt: few-shot prompt that demonstrates the code format.\n",
        "    query: question appended to the prompt.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the LLM response contains no code start marker.\n",
        "    Exception: any error of a trial for which no failing line of get_answer was recorded.\n",
        "  \"\"\"\n",
        "  global errors\n",
        "  global error_lineno\n",
        "  global lines\n",
        "  global trace_lines\n",
        "  global last_state\n",
        "\n",
        "  # Start from a clean tracing state so that calling this function again does not\n",
        "  # reuse the error lines or trace entries of an earlier run.\n",
        "  errors = {}\n",
        "  error_lineno = None\n",
        "  trace_lines = []\n",
        "  last_state = None\n",
        "\n",
        "  coc_response = query_llm(prompt + \"\\n\\n\" + query, max_tokens=1024)\n",
        "  if CODE_START_TOKEN not in coc_response:\n",
        "    raise ValueError(f\"LLM response has no {CODE_START_TOKEN!r} section: {coc_response!r}\")\n",
        "  code_to_run = coc_response.split(CODE_START_TOKEN)[1].split(CODE_END_TOKEN)[0].strip()\n",
        "\n",
        "  max_trials = 20\n",
        "  # Wrap the code inside the get_answer function call\n",
        "  indented_code = \"\\n\".join(\"  \" + code_line for code_line in code_to_run.split(\"\\n\"))\n",
        "  code_to_run = f\"\"\"def get_answer():\n",
        "{indented_code}\n",
        "  return answer\n",
        "answer = get_answer()\"\"\"\n",
        "  lines = code_to_run.split(\"\\n\")\n",
        "  # Namespace that receives the names defined by the executed program.\n",
        "  exec_locals = {\"answer\": None}\n",
        "  # Stays None when every trial fails, so the final print never hits an unbound name.\n",
        "  coc_answer = None\n",
        "\n",
        "  for num_trial in range(max_trials):\n",
        "    # An error raised inside a trace function unsets it, so install it again if needed.\n",
        "    if sys.gettrace() is None: sys.settrace(show_trace)\n",
        "    assert sys.gettrace() is not None, \"get trace is None\"\n",
        "    try:\n",
        "      # answer will be populated by exec function.\n",
        "      exec(code_to_run, globals(), exec_locals)\n",
        "      coc_answer = exec_locals[\"answer\"]\n",
        "      assert coc_answer is not None\n",
        "      break\n",
        "    except Exception:\n",
        "      if error_lineno is None:\n",
        "        # No failing line of get_answer was recorded, so there is nothing to wrap;\n",
        "        # surface the real error instead of hiding it.\n",
        "        raise\n",
        "\n",
        "      # The three lines inserted below (try / except / pass) push every later line down,\n",
        "      # so move the previously recorded error indices along with them.\n",
        "      errors = {(idx + 3 if idx \u003e error_lineno else idx): src for idx, src in errors.items()}\n",
        "      # After inserting \"try:\" the failing line sits one index further down.\n",
        "      line = lines[error_lineno]\n",
        "      errors[error_lineno + 1] = line\n",
        "\n",
        "      # Update lines and code_to_run\n",
        "      num_indent = len(line) - len(line.lstrip())\n",
        "      lines[error_lineno] = \" \" * 2 + lines[error_lineno]\n",
        "      lines.insert(error_lineno, \" \" * num_indent + \"try:\")\n",
        "      lines.insert(error_lineno + 2, \" \" * num_indent + \"except:\")\n",
        "      lines.insert(error_lineno + 3, \" \" * (num_indent + 2) + \"pass\")\n",
        "      code_to_run = \"\\n\".join(lines)\n",
        "\n",
        "      # Reset error_lineno and the trace before the next trial\n",
        "      error_lineno = None\n",
        "      trace_lines = []\n",
        "      last_state = None\n",
        "\n",
        "  print_result('CoC', coc_response, coc_answer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AvJ0COBcdLGs"
      },
      "outputs": [],
      "source": [
        "# This cell runs for roughly one minute.\n",
        "NUM_PLACES = 188\n",
        "# show_trace updates the global progress_bar; the with block closes the bar even if the\n",
        "# evaluation raises.\n",
        "with tqdm(total=NUM_PLACES) as progress_bar:\n",
        "  evaluate_coc(coc_prompt, query)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDoi7ddSef88"
      },
      "source": [
        "# Interactive Demos"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i7Ik1XVFe-Jw"
      },
      "source": [
        "## Code Generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wX0GtQccefgT"
      },
      "outputs": [],
      "source": [
        "code_generation_prompt = \"\"\"Q: Which sentence has the correct adjective order:\n",
        "Options:\n",
        "(A) rubber terrible ship\n",
        "(B) terrible rubber ship\n",
        "A:\n",
        "# CODE START\n",
        "import numpy as np\n",
        "options = {\"(A)\": \"rubber terrible ship\", \"(B)\": \"terrible rubber ship\"}\n",
        "priority = {\"opinion\": 1, \"size\": 2, \"age\": 3, \"shape\": 4, \"color\": 5, \"origin\": 6, \"material\": 7, \"purpose\": 8}\n",
        "valid_types = list(priority.keys())\n",
        "scores = []\n",
        "for option, sentence in options.items():\n",
        "  adjectives = sentence.split(\" \")[:-1]\n",
        "  order = [priority[get_adjective_type(adjective, valid_types, ret_type=str)] for adjective in adjectives]\n",
        "  scores.append([order[i+1] \u003e order[i] for i in range(len(order) - 1)].count(True))\n",
        "answer = list(options.keys())[np.argmax(scores)]\n",
        "# CODE END\n",
        "\n",
        "Q: Today is Christmas Eve of 1937. What is the date 10 days ago in MM/DD/YYYY?\n",
        "A:\n",
        "# CODE START\n",
        "import datetime\n",
        "options = {\"12/14/2026\": \"(A)\", \"12/14/1950\": \"(B)\", \"12/14/2007\": \"(C)\", \"12/14/1937\": \"(D)\", \"07/14/1938\": \"(E)\", \"12/14/1988\": \"(F)\"}\n",
        "today = datetime.date(year=1937, month=12, day=24)\n",
        "date = today - datetime.timedelta(days=10)\n",
        "answer = date.strftime(\"%m/%d/%Y\")\n",
        "# CODE END\n",
        "\n",
        "Q: Recommend a movie similar to Star Wars Episode IV - A New Hope, Indiana Jones and the Last Crusade, Star Wars Episode V - The Empire Strikes Back, The Big Lebowski:\n",
        "A:\n",
        "# CODE START\n",
        "ref_movies = [\"Star Wars Episode IV - A New Hope\", \"Indiana Jones and the Last Crusade\", \"Star Wars Episode V - The Empire Strikes Back\", \"The Big Lebowski\"]\n",
        "ref_movie_infos = get_movie_genre_and_year(ref_movies, ret_type=\"list[tuple[str, int]]\")\n",
        "answer = get_most_similar_movie(ref_movies, ref_movie_infos, ret_type=\"str\")\n",
        "# CODE END\n",
        "\n",
        "Q: \"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0JfB-Fvdfncw"
      },
      "outputs": [],
      "source": [
        "def evaluate_code_generation(prompt, query):\n",
        "  \"\"\"Print the LLM completion for prompt + query, stopping before the next question.\"\"\"\n",
        "  generated_code = query_llm(prompt + query, max_tokens=256, stop=[\"\\n\\nQ:\"])\n",
        "  print(generated_code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4IhAVzhdjLK7"
      },
      "source": [
        "This interactive demo showcases the code generation aspect of Chain of Code.\n",
        "\n",
        "Aside from the three examples below, feel free to try out any reasoning task of your choice by typing it into the input box on the right-hand side.\n",
        "Our model should be able to output reasonable Python code/pseudocode to solve the task."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BZfkweRaft4r"
      },
      "outputs": [],
      "source": [
        "example_1 = \"What type of food does two concentric circles look like?\"\n",
        "example_2 = \"If I stack three Eiffel Towers on top of each other, how tall is the new tower?\"\n",
        "example_3 = \"What are 10 smallest countries in the world?\"\n",
        "# Use a dedicated name so that the `query` consumed by the Demos section is not overwritten.\n",
        "code_generation_query = \"What type of food does two concentric circles look like?\" #@param {allow-input: true, type:\"string\"}\n",
        "code_generation_query += \"\\nA:\\n\"\n",
        "\n",
        "evaluate_code_generation(code_generation_prompt, code_generation_query)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VAlgtnc7fATy"
      },
      "source": [
        "## Code Execution (LMulator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Yg7xFq4eWpq"
      },
      "outputs": [],
      "source": [
        "lmulator_prompt = \"\"\"# TRACE START\n",
        "state: {}\n",
        "line: adjective = \"red\"\n",
        "delta state: {'adjective': 'red'}\n",
        "line: valid_types = {\"opinion\", \"size\", \"age\", \"shape\", \"color\", \"origin\", \"material\", \"purpose\"}\n",
        "delta state: {'valid_types': {'opinion', 'size', 'age', 'shape', 'color', 'origin', 'material', 'purpose'}}\n",
        "line: adj_type = get_adjective_type(adjective, valid_types, ret_type=str)\n",
        "delta state: {'adj_type': 'color'}\n",
        "# TRACE END\n",
        "\n",
        "# TRACE START\n",
        "state: {}\n",
        "line: obj1 = \"soda can\"\n",
        "delta state: {'obj1': 'soda can'}\n",
        "line: is_obj1_recyclable = is_recyclable(obj1, ret_type=bool)\n",
        "delta state: {'is_obj1_recyclable': True}\n",
        "line: obj2 = \"fruit\"\n",
        "delta state: {'obj2': 'fruit'}\n",
        "line: is_obj2_recyclable = is_recyclable(obj2, ret_type=bool)\n",
        "delta state: {'is_obj2_recyclable': False}\n",
        "# TRACE END\n",
        "\n",
        "# TRACE START\n",
        "state: {}\n",
        "line: num1 = 23\n",
        "delta state: {'num1': 23}\n",
        "line: num2 = 52\n",
        "delta state: {'num2': 52}\n",
        "line: sum_of_two = num1 + num2\n",
        "delta state: {'sum_of_two': 75}\n",
        "line: greated_than_one_hundred = is_greater(sum_of_two, 100)\n",
        "delta state: {'greated_than_one_hundred': False}\n",
        "# TRACE END\n",
        "\n",
        "# TRACE START\n",
        "state: {}\n",
        "line: \"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-v2D93MafQPJ"
      },
      "outputs": [],
      "source": [
        "def evaluate_lmulator(prompt, query):\n",
        "  \"\"\"Print the LLM completion for prompt + query, stopping at the next trace line or the trace end.\"\"\"\n",
        "  simulated_trace = query_llm(prompt + query, max_tokens=256, stop=[\"\\nline:\", \"\\n# TRACE END\"])\n",
        "  print(simulated_trace)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wXNtdb0qjyi8"
      },
      "source": [
        "This interactive demo showcases the code execution aspect of Chain of Code (LMulator).\n",
        "\n",
        "Aside from the three examples below, feel free to type any Python code/pseudocode whose execution you want the LMulator to simulate into the input box on the right-hand side.\n",
        "\n",
        "Our model should be able to output a reasonable delta state as the result of simulated code execution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X0jRGkUzfMF7"
      },
      "outputs": [],
      "source": [
        "example_1 = \"extinct_animal = which_animal_is_extinct(['panda', 'dinosaur', 'pig'], ret_type=str)\"\n",
        "example_2 = \"divisible_test = is_divisible(divident=2142, divisor=17, ret_type=bool)\"\n",
        "example_3 = \"distance_between_cities = get_distance(city1='San Francisco', city2='New York', unit='kilometers', ret_type=float)\"\n",
        "# Use a dedicated name so that the `query` consumed by the Demos section is not overwritten.\n",
        "lmulator_query = \"extinct_animal = which_animal_is_extinct(['panda', 'dinosaur', 'pig'], ret_type=str)\" #@param {allow-input: true, type:\"string\"}\n",
        "\n",
        "evaluate_lmulator(lmulator_prompt, lmulator_query)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1UGoztJO-KvBg2KCXhvu7lFjlvZIF1H4M",
          "timestamp": 1707175089868
        },
        {
          "file_id": "1G-wp_L3brvDTkthlEryOrpvTYFJISbtt",
          "timestamp": 1707175048199
        },
        {
          "file_id": "https://github.com/google-research/google-research/blob/master/code_as_policies/coc_demo.ipynb",
          "timestamp": 1706132211285
        },
        {
          "file_id": "1Zsdat9J2_G3HFWOfYDFvLNI-zKMoQMTu",
          "timestamp": 1702611985486
        },
        {
          "file_id": "1qjyq0oMXefs6ERvCXRcNkz0PDKBLp3ji",
          "timestamp": 1702611855599
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
